import argparse
import ast
import os

import torch

# Plot
import matplotlib.pyplot as plt
from utils import get_args_table

# Data
from data import get_data, dataset_choices

# Model
import torch.nn as nn
from varflow.flows import Flow
from varflow.distributions import StandardNormal
from varflow.transforms import AdditiveCouplingBijection, AffineCouplingBijection, ActNormBijection, Reverse, Augment
from varflow.nn.nets import MLP
from varflow.nn.layers import ElementwiseParams, LambdaLayer, scale_fn

# Eval
from varflow.utils import iwbo, iwbo_batched, iwbo_nats

# Optim
from torch.optim import Adam, Adamax

###########
## Setup ##
###########

parser = argparse.ArgumentParser()

# Data params
parser.add_argument('--dataset', type=str, default='checkerboard', choices=dataset_choices)
parser.add_argument('--train_samples', type=int, default=128*1000)
parser.add_argument('--test_samples', type=int, default=128*1000)

# Model params
# SECURITY: these flags previously used type=eval, which executes arbitrary
# Python supplied on the command line. ast.literal_eval parses the same
# literal inputs ('True', 'False', '[50, 50]', ...) without code execution.
parser.add_argument('--num_flows', type=int, default=4)
parser.add_argument('--actnorm', type=ast.literal_eval, default=False)
parser.add_argument('--affine', type=ast.literal_eval, default=True)
parser.add_argument('--scale_fn', type=str, default='softplus', choices={'exp', 'softplus', 'sigmoid', 'tanh_exp'})
parser.add_argument('--hidden_units', type=ast.literal_eval, default=[50])
parser.add_argument('--activation', type=str, default='relu', choices={'relu', 'elu', 'gelu'})
parser.add_argument('--augdim', type=int, default=2)

# Train params
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--optimizer', type=str, default='adam', choices={'adam', 'adamax'})
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--device', type=str, default='cpu')

# Eval params
parser.add_argument('--k', type=int, default=10)      # importance samples for the IWBO estimate
parser.add_argument('--kbs', type=int, default=None)  # batch size over k (None = all k at once)

# Plot params
parser.add_argument('--num_samples', type=int, default=128*1000)
parser.add_argument('--grid_size', type=int, default=500)
parser.add_argument('--pixels', type=int, default=1000)
parser.add_argument('--dpi', type=int, default=96)
parser.add_argument('--clim', type=float, default=0.05)
# Check the DPI of your monitor at: https://www.infobyip.com/detectmonitordpi.php

args = parser.parse_args()

torch.manual_seed(0)  # fixed seed: reproducible training and sampling
os.makedirs('figures', exist_ok=True)  # race-free replacement for exists() + mkdir()

##################
## Specify data ##
##################

# Build train/test DataLoaders from args (dataset, sample counts, batch size);
# see data.py. Loaders yield plain tensors — presumably of shape (batch, 2),
# given D = 2 below; confirm against data.py.
train_loader, test_loader = get_data(args)

###################
## Specify model ##
###################

assert args.augdim % 2 == 0  # coupling layers split the augmented vector in half

D = 2 # Number of data dimensions
A = D + args.augdim # Number of augmented data dimensions (was hard-coded as 2 + augdim)
P = 2 if args.affine else 1 # Number of elementwise parameters (shift, or shift + scale)

# Augment the D-dim data with `augdim` Gaussian dimensions, then apply
# `num_flows` coupling blocks over the full A-dim augmented space.
transforms = [Augment(StandardNormal((args.augdim,)), x_size=D)]
for _ in range(args.num_flows):
    # Coupling net: conditions on one half (A//2 dims), outputs P params per
    # dim of the other half.
    net = nn.Sequential(MLP(A//2, P*A//2,
                            hidden_units=args.hidden_units,
                            activation=args.activation),
                        ElementwiseParams(P))
    if args.affine: transforms.append(AffineCouplingBijection(net, scale_fn=scale_fn(args.scale_fn)))
    else:           transforms.append(AdditiveCouplingBijection(net))
    # BUG FIX: every transform after Augment acts on A-dim tensors, so ActNorm
    # must be sized A; ActNormBijection(D) shape-mismatches when --actnorm True.
    if args.actnorm: transforms.append(ActNormBijection(A))
    transforms.append(Reverse(A))
transforms.pop()  # drop the trailing Reverse — nothing follows it


model = Flow(base_dist=StandardNormal((A,)),
             transforms=transforms).to(args.device)

#######################
## Specify optimizer ##
#######################

# Dispatch on the --optimizer choice ({'adam', 'adamax'}); like the original
# if/elif chain, an unrecognized value leaves `optimizer` unbound.
_opt_cls = {'adam': Adam, 'adamax': Adamax}.get(args.optimizer)
if _opt_cls is not None:
    optimizer = _opt_cls(model.parameters(), lr=args.lr)

##############
## Training ##
##############

print('Training...')
# BUG FIX: initialize loss_sum before the epoch loop so the post-loop average
# is defined even for --epochs 0 (previously a NameError).
loss_sum = 0.0
for epoch in range(args.epochs):
    loss_sum = 0.0
    for i, x in enumerate(train_loader):
        optimizer.zero_grad()
        # Negative mean log-likelihood, in nats.
        loss = -model.log_prob(x.to(args.device)).mean()
        loss.backward()
        optimizer.step()
        loss_sum += loss.detach().cpu().item()
        print('Epoch: {}/{}, Iter: {}/{}, Nats: {:.3f}'.format(epoch+1, args.epochs, i+1, len(train_loader), loss_sum/(i+1)), end='\r')
    print('')
# Average loss over the final epoch. NOTE(review): the name says bpd
# (bits/dim) but the value is in nats — kept for compatibility; consider
# renaming to final_train_nats.
final_train_bpd = loss_sum / len(train_loader)

#############
## Testing ##
#############

print('Testing...')
with torch.no_grad():
    loss_sum = 0.0
    num_batches = len(test_loader)
    for batch_idx, batch in enumerate(test_loader, start=1):
        # Importance-weighted bound (IWBO) in nats with k samples; kbs
        # optionally batches over k.
        batch_nats = iwbo_nats(model, batch.to(args.device), k=args.k, kbs=args.kbs)
        loss_sum += batch_nats.detach().cpu().item()
        print('Iter: {}/{}, Nats: {:.3f}'.format(batch_idx, num_batches, loss_sum/batch_idx), end='\r')
    print('')
final_test_nats = loss_sum / len(test_loader)

##############
## Sampling ##
##############

print('Sampling...')
# face_einstein lives in the unit square; every other toy dataset in [-4, 4]^2.
bounds = [[0, 1], [0, 1]] if args.dataset == 'face_einstein' else [[-4, 4], [-4, 4]]

# Plot samples as a 2-D histogram over `bounds`.
samples = model.sample(args.num_samples).detach().cpu().numpy()
fig_inches = args.pixels / args.dpi
plt.figure(figsize=(fig_inches, fig_inches), dpi=args.dpi)
plt.hist2d(samples[..., 0], samples[..., 1], bins=256, range=bounds)
plt.xlim(bounds[0])
plt.ylim(bounds[1])
plt.axis('off')
plt.savefig('figures/{}_aug_flow_samples.png'.format(args.dataset), bbox_inches='tight', pad_inches=0)

# Plot density
# Evaluate the IWBO density estimate on a grid_size x grid_size lattice over `bounds`.
xv, yv = torch.meshgrid([torch.linspace(bounds[0][0], bounds[0][1], args.grid_size), torch.linspace(bounds[1][0], bounds[1][1], args.grid_size)])
x = torch.cat([xv.reshape(-1,1), yv.reshape(-1,1)], dim=-1).to(args.device)
with torch.no_grad():
    # IWBO with k importance samples lower-bounds log p(x); kbs batches over k.
    if args.kbs: logprobs = iwbo_batched(model=model, x=x, k=args.k, kbs=args.kbs)
    else:        logprobs = iwbo(model=model, x=x, k=args.k)
# BUG FIX: move the probabilities to the CPU once — .numpy() and pcolormesh
# both fail on CUDA tensors when --device is not 'cpu'. Also avoids computing
# exp() three times.
probs = logprobs.exp().cpu()
plt.figure(figsize=(args.pixels/args.dpi, args.pixels/args.dpi), dpi=args.dpi)
plt.pcolormesh(xv, yv, probs.reshape(xv.shape))
plt.xlim(bounds[0])
plt.ylim(bounds[1])
plt.axis('off')
print('Range:', probs.min().numpy(), probs.max().numpy())
print('Limits:', 0.0, args.clim)
plt.clim(0, args.clim)
plt.savefig('figures/{}_aug_flow_density.png'.format(args.dataset), bbox_inches='tight', pad_inches=0)

# Save the final test log-likelihood (nats) for this dataset.
loglik_path = 'figures/{}_aug_flow_loglik.txt'.format(args.dataset)
with open(loglik_path, 'w') as f:
    f.write(str(final_test_nats))

# Save the full argument table alongside the figures.
args_path = 'figures/{}_aug_flow_args.txt'.format(args.dataset)
with open(args_path, 'w') as f:
    f.write(str(get_args_table(vars(args))))

# Display plots
plt.show()
